{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qFdPvlXBOdUN"
      },
      "source": [
        "# Mixed precision"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/mixed_precision\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/mixed_precision.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/mixed_precision.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/mixed_precision.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xHxb-dlhMIzW"
      },
      "source": [
        "## Overview\n",
        "\n",
        "Mixed precision is the use of both 16-bit and 32-bit floating-point types in a model during training to make it run faster and use less memory. By keeping certain parts of the model in the 32-bit types for numeric stability, the model will have a lower step time and train equally well in terms of evaluation metrics such as accuracy. This guide describes how to use the Keras mixed precision API to speed up your models. Using this API can improve performance by more than 3 times on modern GPUs and by 60% on TPUs."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3vsYi_bv7gS_"
      },
      "source": [
        "Today, most models use the float32 dtype, which takes 32 bits of memory. However, there are two lower-precision dtypes, float16 and bfloat16, each of which takes 16 bits of memory instead. Modern accelerators can run operations faster in the 16-bit dtypes, as they have specialized hardware to run 16-bit computations and 16-bit dtypes can be read from memory faster.\n",
        "\n",
        "NVIDIA GPUs can run operations in float16 faster than in float32, and TPUs can run operations in bfloat16 faster than float32. Therefore, these lower-precision dtypes should be used whenever possible on those devices. However, variables and a few computations should still be in float32 for numeric reasons so that the model trains to the same quality. The Keras mixed precision API allows you to use a mix of either float16 or bfloat16 with float32, to get the performance benefits from float16/bfloat16 and the numeric stability benefits from float32.\n",
        "\n",
        "Note: In this guide, the term \"numeric stability\" refers to how a model's quality is affected by the use of a lower-precision dtype instead of a higher precision dtype. An operation is \"numerically unstable\" in float16 or bfloat16 if running it in one of those dtypes causes the model to have worse evaluation accuracy or other metrics compared to running the operation in float32."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MUXex9ctTuDB"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IqR2PQG4ZaZ0"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "from tensorflow import keras\n",
        "from tensorflow.keras import layers\n",
        "from tensorflow.keras import mixed_precision"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "814VXqdh8Q0r"
      },
      "source": [
        "## Supported hardware\n",
        "\n",
        "While mixed precision will run on most hardware, it will only speed up models on recent NVIDIA GPUs and Cloud TPUs. NVIDIA GPUs support using a mix of float16 and float32, while TPUs support a mix of bfloat16 and float32.\n",
        "\n",
        "Among NVIDIA GPUs, those with compute capability 7.0 or higher will see the greatest performance benefit from mixed precision because they have special hardware units, called Tensor Cores, to accelerate float16 matrix multiplications and convolutions. Older GPUs offer no math performance benefit from mixed precision; however, memory and bandwidth savings can still enable some speedups. You can look up the compute capability for your GPU at NVIDIA's [CUDA GPU web page](https://developer.nvidia.com/cuda-gpus). Examples of GPUs that will benefit most from mixed precision include RTX GPUs, the V100, and the A100."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-q2hisD60F0_"
      },
      "source": [
        "Note: If running this guide in Google Colab, the GPU runtime typically has a P100 connected. The P100 has compute capability 6.0 and is not expected to show a significant speedup.\n",
        "\n",
        "You can check your GPU type with the following. The command only exists if the\n",
        "NVIDIA drivers are installed, so the following will raise an error otherwise."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "j-Yzg_lfkoa_"
      },
      "outputs": [],
      "source": [
        "!nvidia-smi -L"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hu_pvZDN0El3"
      },
      "source": [
        "All Cloud TPUs support bfloat16.\n",
        "\n",
        "Even on CPUs and older GPUs, where no speedup is expected, mixed precision APIs can still be used for unit testing, debugging, or just to try out the API. On CPUs, mixed precision will run significantly slower, however."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HNOmvumB-orT"
      },
      "source": [
        "## Setting the dtype policy"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "54ecYY2Hn16E"
      },
      "source": [
        "To use mixed precision in Keras, you need to create a `tf.keras.mixed_precision.Policy`, typically referred to as a *dtype policy*. Dtype policies specify the dtypes layers will run in. In this guide, you will construct a policy from the string `'mixed_float16'` and set it as the global policy. This will cause subsequently created layers to use mixed precision with a mix of float16 and float32."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x3kElPVH-siO"
      },
      "outputs": [],
      "source": [
        "policy = mixed_precision.Policy('mixed_float16')\n",
        "mixed_precision.set_global_policy(policy)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6ids1rT_UM5q"
      },
      "source": [
        "As a shorthand, you can pass a string directly to `set_global_policy`, which is what is typically done in practice."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6a8iNFoBUSqR"
      },
      "outputs": [],
      "source": [
        "# Equivalent to the two lines above\n",
        "mixed_precision.set_global_policy('mixed_float16')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oGAMaa0Ho3yk"
      },
      "source": [
        "The policy specifies two important aspects of a layer: the dtype the layer's computations are done in, and the dtype of a layer's variables. Above, you created a `mixed_float16` policy (i.e., a `mixed_precision.Policy` created by passing the string `'mixed_float16'` to its constructor). With this policy, layers use float16 computations and float32 variables. Computations are done in float16 for performance, but variables must be kept in float32 for numeric stability. You can directly query these properties of the policy."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GQRbYm4f8p-k"
      },
      "outputs": [],
      "source": [
        "print('Compute dtype: %s' % policy.compute_dtype)\n",
        "print('Variable dtype: %s' % policy.variable_dtype)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MOFEcna28o4T"
      },
      "source": [
        "As mentioned before, the `mixed_float16` policy will most significantly improve performance on NVIDIA GPUs with compute capability of at least 7.0. The policy will run on other GPUs and CPUs but may not improve performance. For TPUs, the `mixed_bfloat16` policy should be used instead."
      ]
    },
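    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick check that does not require a TPU, the `mixed_bfloat16` policy can be constructed without making it the global policy, and its dtypes inspected in the same way. The name `bf16_policy` is used only for this illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from tensorflow.keras import mixed_precision\n",
        "\n",
        "# Construct the policy without calling set_global_policy, so the rest of this\n",
        "# guide keeps using 'mixed_float16'.\n",
        "bf16_policy = mixed_precision.Policy('mixed_bfloat16')\n",
        "print('Compute dtype: %s' % bf16_policy.compute_dtype)\n",
        "print('Variable dtype: %s' % bf16_policy.variable_dtype)"
      ]
    },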
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cAHpt128tVpK"
      },
      "source": [
        "## Building the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nB6ujaR8qMAy"
      },
      "source": [
        "Next, let's start building a simple model. Very small toy models typically do not benefit from mixed precision, because overhead from the TensorFlow runtime typically dominates the execution time, making any performance improvement on the GPU negligible. Therefore, let's build two large `Dense` layers with 4096 units each if a GPU is used."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0DQM24hL_14Q"
      },
      "outputs": [],
      "source": [
        "inputs = keras.Input(shape=(784,), name='digits')\n",
        "if tf.config.list_physical_devices('GPU'):\n",
        "  print('The model will run with 4096 units on a GPU')\n",
        "  num_units = 4096\n",
        "else:\n",
        "  # Use fewer units on CPUs so the model finishes in a reasonable amount of time\n",
        "  print('The model will run with 64 units on a CPU')\n",
        "  num_units = 64\n",
        "dense1 = layers.Dense(num_units, activation='relu', name='dense_1')\n",
        "x = dense1(inputs)\n",
        "dense2 = layers.Dense(num_units, activation='relu', name='dense_2')\n",
        "x = dense2(x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2dezdcqnOXHk"
      },
      "source": [
        "Each layer has a policy and uses the global policy by default. Each of the `Dense` layers therefore has the `mixed_float16` policy because you set the global policy to `mixed_float16` previously. This will cause the dense layers to do float16 computations and have float32 variables. They cast their inputs to float16 in order to do float16 computations, which causes their outputs to be float16 as a result. Their variables are float32 and will be cast to float16 when the layers are called to avoid errors from dtype mismatches."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kC58MzP4PEcC"
      },
      "outputs": [],
      "source": [
        "print(dense1.dtype_policy)\n",
        "print('x.dtype: %s' % x.dtype.name)\n",
        "# 'kernel' is dense1's variable\n",
        "print('dense1.kernel.dtype: %s' % dense1.kernel.dtype.name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_WAZeqDyqZcb"
      },
      "source": [
        "Next, create the output predictions. Normally, you can create the output predictions as follows, but this is not always numerically stable with float16."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ybBq1JDwNIbz"
      },
      "outputs": [],
      "source": [
        "# INCORRECT: softmax and model output will be float16, when it should be float32\n",
        "outputs = layers.Dense(10, activation='softmax', name='predictions')(x)\n",
        "print('Outputs dtype: %s' % outputs.dtype.name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D0gSWxc9NN7q"
      },
      "source": [
        "A softmax activation at the end of the model should be float32. Because the dtype policy is `mixed_float16`, the softmax activation would normally have a float16 compute dtype and output float16 tensors.\n",
        "\n",
        "This can be fixed by separating the Dense and softmax layers, and by passing `dtype='float32'` to the softmax layer:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IGqCGn4BsODw"
      },
      "outputs": [],
      "source": [
        "# CORRECT: softmax and model output are float32\n",
        "x = layers.Dense(10, name='dense_logits')(x)\n",
        "outputs = layers.Activation('softmax', dtype='float32', name='predictions')(x)\n",
        "print('Outputs dtype: %s' % outputs.dtype.name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tUdkY_DHsP8i"
      },
      "source": [
        "Passing `dtype='float32'` to the softmax layer constructor overrides the layer's dtype policy to be the `float32` policy, which does computations and keeps variables in float32. Equivalently, you could have instead passed `dtype=mixed_precision.Policy('float32')`; layers always convert the dtype argument to a policy. Because the `Activation` layer has no variables, the policy's variable dtype is ignored, but the policy's compute dtype of float32 causes softmax and the model output to be float32. \n",
        "\n",
        "\n",
        "Adding a float16 softmax in the middle of a model is fine, but a softmax at the end of the model should be in float32. The reason is that if the intermediate tensor flowing from the softmax to the loss is float16 or bfloat16, numeric issues may occur.\n",
        "\n",
        "You can override the dtype of any layer to be float32 by passing `dtype='float32'` if you think it will not be numerically stable with float16 computations. But typically, this is only necessary on the last layer of the model, as most layers have sufficient precision with `mixed_float16` and `mixed_bfloat16`.\n",
        "\n",
        "Even if the model does not end in a softmax, the outputs should still be float32. While unnecessary for this specific model, the model outputs can be cast to float32 with the following:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dzVAoLI56jR8"
      },
      "outputs": [],
      "source": [
        "# The linear activation is an identity function. So this simply casts 'outputs'\n",
        "# to float32. In this particular case, 'outputs' is already float32 so this is a\n",
        "# no-op.\n",
        "outputs = layers.Activation('linear', dtype='float32')(outputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tpY4ZP7us5hA"
      },
      "source": [
        "Next, finish and compile the model, and generate input data:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g4OT3Z6kqYAL"
      },
      "outputs": [],
      "source": [
        "model = keras.Model(inputs=inputs, outputs=outputs)\n",
        "model.compile(loss='sparse_categorical_crossentropy',\n",
        "              optimizer=keras.optimizers.RMSprop(),\n",
        "              metrics=['accuracy'])\n",
        "\n",
        "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
        "x_train = x_train.reshape(60000, 784).astype('float32') / 255\n",
        "x_test = x_test.reshape(10000, 784).astype('float32') / 255"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Sm8FJHegVRN"
      },
      "source": [
        "This example casts the input data from uint8 to float32. You don't cast to float16 since the division by 255 is on the CPU, which runs float16 operations slower than float32 operations. In this case, the performance difference is negligible, but in general you should run input processing math in float32 if it runs on the CPU. The first layer of the model will cast the inputs to float16, as each layer casts floating-point inputs to its compute dtype.\n",
        "\n",
        "The initial weights of the model are retrieved. This will allow training from scratch again by loading the weights."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0UYs-u_DgiA5"
      },
      "outputs": [],
      "source": [
        "initial_weights = model.get_weights()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zlqz6eVKs9aU"
      },
      "source": [
        "## Training the model with Model.fit\n",
        "\n",
        "Next, train the model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hxI7-0ewmC0A"
      },
      "outputs": [],
      "source": [
        "history = model.fit(x_train, y_train,\n",
        "                    batch_size=8192,\n",
        "                    epochs=5,\n",
        "                    validation_split=0.2)\n",
        "test_scores = model.evaluate(x_test, y_test, verbose=2)\n",
        "print('Test loss:', test_scores[0])\n",
        "print('Test accuracy:', test_scores[1])\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MPhJ9OPWt4x5"
      },
      "source": [
        "Notice the model prints the time per step in the logs: for example, \"25ms/step\". The first epoch may be slower as TensorFlow spends some time optimizing the model, but afterwards the time per step should stabilize. \n",
        " \n",
        "If you are running this guide in Colab, you can compare the performance of mixed precision with float32. To do so, change the policy from `mixed_float16` to `float32` in the \"Setting the dtype policy\" section, then rerun all the cells up to this point. On GPUs with compute capability 7.X, you should see the time per step significantly increase, indicating mixed precision sped up the model. Make sure to change the policy back to `mixed_float16` and rerun the cells before continuing with the guide.\n",
        "\n",
        "On GPUs with compute capability of at least 8.0 (Ampere GPUs and above), you likely will see no performance improvement in the toy model in this guide when using mixed precision compared to float32. This is due to the use of [TensorFloat-32](https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_tensor_float_32_execution), which automatically uses lower precision math in certain float32 ops such as `tf.linalg.matmul`. TensorFloat-32 gives some of the performance advantages of mixed precision when using float32. However, in real-world models, you will still typically see significant performance improvements from mixed precision due to memory bandwidth savings and ops which TensorFloat-32 does not support.\n",
        "\n",
        "If running mixed precision on a TPU, you will not see as much of a performance gain compared to running mixed precision on GPUs, especially pre-Ampere GPUs. This is because TPUs do certain ops in bfloat16 under the hood even with the default dtype policy of float32. This is similar to how Ampere GPUs use TensorFloat-32 by default. Compared to Ampere GPUs, TPUs typically see smaller performance gains with mixed precision on real-world models.\n",
        "\n",
        "For many real-world models, mixed precision also allows you to double the batch size without running out of memory, as float16 tensors take half the memory. This does not apply, however, to this toy model, as you can likely run the model in any dtype where each batch consists of the entire MNIST dataset of 60,000 images."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mNKMXlCvHgHb"
      },
      "source": [
        "## Loss scaling\n",
        "\n",
        "Loss scaling is a technique which `tf.keras.Model.fit` automatically performs with the `mixed_float16` policy to avoid numeric underflow. This section describes what loss scaling is and the next section describes how to use it with a custom training loop."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1xQX62t2ow0g"
      },
      "source": [
        "### Underflow and Overflow\n",
        "\n",
        "The float16 data type has a narrow dynamic range compared to float32. This means values above $65504$ will overflow to infinity and values below $6.0 \\times 10^{-8}$ will underflow to zero. float32 and bfloat16 have a much higher dynamic range so that overflow and underflow are not a problem.\n",
        "\n",
        "For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CHmXRb-yRWbE"
      },
      "outputs": [],
      "source": [
        "x = tf.constant(256, dtype='float16')\n",
        "(x ** 2).numpy()  # Overflow"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5unZLhN0RfQM"
      },
      "outputs": [],
      "source": [
        "x = tf.constant(1e-5, dtype='float16')\n",
        "(x ** 2).numpy()  # Underflow"
      ]
    },
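    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The limits quoted above can also be read off the dtype objects themselves. The following sketch prints the largest finite value and the per-element size in bytes for each dtype, using the `max` and `size` properties of `tf.dtypes.DType`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "# Compare the range and storage size of the three dtypes discussed in this guide.\n",
        "for dtype in (tf.float16, tf.bfloat16, tf.float32):\n",
        "  print('%s: max=%s, bytes per element=%d' % (dtype.name, dtype.max, dtype.size))"
      ]
    },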
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pUIbhQypRVe_"
      },
      "source": [
        "In practice, overflow with float16 rarely occurs. Additionally, underflow also rarely occurs during the forward pass. However, during the backward pass, gradients can underflow to zero. Loss scaling is a technique to prevent this underflow."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FAL5qij_oNqJ"
      },
      "source": [
        "### Loss scaling overview\n",
        "\n",
        "The basic concept of loss scaling is simple: multiply the loss by some large number, say $1024$, which is called the *loss scale*. This will cause the gradients to scale by $1024$ as well, greatly reducing the chance of underflow. Once the final gradients are computed, divide them by $1024$ to bring them back to their correct values.\n",
        "\n",
        "The pseudocode for this process is:\n",
        "\n",
        "```\n",
        "loss_scale = 1024\n",
        "loss = model(inputs)\n",
        "loss *= loss_scale\n",
        "# Assume `grads` are float32. You do not want to divide float16 gradients.\n",
        "grads = compute_gradient(loss, model.trainable_variables)\n",
        "grads /= loss_scale\n",
        "```\n",
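        "\n",
        "A runnable toy version of the pseudocode above might look like the following. It is only a sketch: the names `demo_var`, `a`, `b`, and `demo_gradient` are made up for this illustration, and the loss scale of 32768 is an arbitrary value chosen so that the scaled gradient lands well inside the float16 range.\n",
        "\n",
        "```python\n",
        "import tensorflow as tf\n",
        "\n",
        "# A float32 variable and two small float16 constants. The true gradient of the\n",
        "# loss with respect to demo_var is a * b = 1e-8, which is too small for float16.\n",
        "demo_var = tf.Variable(1.0)\n",
        "a = tf.constant(1e-4, dtype=tf.float16)\n",
        "b = tf.constant(1e-4, dtype=tf.float16)\n",
        "\n",
        "def demo_gradient(loss_scale):\n",
        "  with tf.GradientTape() as tape:\n",
        "    # The forward computation runs in float16, as in a mixed_float16 layer.\n",
        "    # (The forward value also underflows here; only the gradient matters.)\n",
        "    demo_loss = tf.cast(demo_var, tf.float16) * a * b\n",
        "    # Cast the loss to float32, then multiply it by the loss scale.\n",
        "    scaled_loss = tf.cast(demo_loss, tf.float32) * loss_scale\n",
        "  # The gradient of a float32 variable is float32, so dividing it is safe.\n",
        "  scaled_grad = tape.gradient(scaled_loss, demo_var)\n",
        "  return scaled_grad / loss_scale\n",
        "\n",
        "# With a scale of 1 the float16 backward pass is expected to underflow to zero.\n",
        "print('Without loss scaling:', demo_gradient(1.0).numpy())\n",
        "print('With a loss scale of 32768:', demo_gradient(32768.0).numpy())\n",
        "```\n",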
        "\n",
        "Choosing a loss scale can be tricky. If the loss scale is too low, gradients may still underflow to zero. If too high, the opposite problem occurs: the gradients may overflow to infinity.\n",
        "\n",
        "To solve this, TensorFlow dynamically determines the loss scale so you do not have to choose one manually. If you use `tf.keras.Model.fit`, loss scaling is done for you so you do not have to do any extra work. If you use a custom training loop, you must explicitly use the special optimizer wrapper `tf.keras.mixed_precision.LossScaleOptimizer` in order to use loss scaling. This is described in the next section.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yqzbn8Ks9Q98"
      },
      "source": [
        "## Training the model with a custom training loop"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CRANRZZ69nA7"
      },
      "source": [
        "So far, you have trained a Keras model with mixed precision using `tf.keras.Model.fit`. Next, you will use mixed precision with a custom training loop. If you do not already know what a custom training loop is, please read the [Custom training guide](../tutorials/customization/custom_training_walkthrough.ipynb) first."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wXTaM8EEyEuo"
      },
      "source": [
        "Running a custom training loop with mixed precision requires two changes over running it in float32:\n",
        "\n",
        "1. Build the model with mixed precision (you already did this)\n",
        "2. Explicitly use loss scaling if `mixed_float16` is used.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M2zpp7_65mTZ"
      },
      "source": [
        "For step (2), you will use the `tf.keras.mixed_precision.LossScaleOptimizer` class, which wraps an optimizer and applies loss scaling. By default, it dynamically determines the loss scale so you do not have to choose one. Construct a `LossScaleOptimizer` as follows."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ogZN3rIH0vpj"
      },
      "outputs": [],
      "source": [
        "optimizer = keras.optimizers.RMSprop()\n",
        "optimizer = mixed_precision.LossScaleOptimizer(optimizer)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FVy5gnBqTE9z"
      },
      "source": [
        "If you want, it is possible to choose an explicit loss scale or otherwise customize the loss scaling behavior, but it is highly recommended to keep the default loss scaling behavior, as it has been found to work well on all known models. See the `tf.keras.mixed_precision.LossScaleOptimizer` documentation if you want to customize the loss scaling behavior."
      ]
    },
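    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For illustration only, the following sketch customizes the starting loss scale and the growth interval. It assumes the `initial_scale` and `dynamic_growth_steps` constructor arguments of `LossScaleOptimizer`; the values are arbitrary, not recommendations. The wrapper is stored under the made-up name `custom_optimizer` so the `optimizer` used by the rest of this guide is left unchanged."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from tensorflow import keras\n",
        "from tensorflow.keras import mixed_precision\n",
        "\n",
        "custom_optimizer = mixed_precision.LossScaleOptimizer(\n",
        "    keras.optimizers.RMSprop(),\n",
        "    initial_scale=2 ** 10,      # Arbitrary starting loss scale\n",
        "    dynamic_growth_steps=1000)  # Steps with finite gradients before the scale grows\n",
        "print('Initial loss scale: %s' % custom_optimizer.initial_scale)\n",
        "print('Dynamic growth steps: %s' % custom_optimizer.dynamic_growth_steps)"
      ]
    },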
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JZYEr5hA3MXZ"
      },
      "source": [
        "Next, define the loss object and the `tf.data.Dataset`s:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9cE7Mm533hxe"
      },
      "outputs": [],
      "source": [
        "loss_object = tf.keras.losses.SparseCategoricalCrossentropy()\n",
        "train_dataset = (tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
        "                 .shuffle(10000).batch(8192))\n",
        "test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(8192)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4W0zxrxC3nww"
      },
      "source": [
        "Next, define the training step function. You will use two new methods from the loss scale optimizer to scale the loss and unscale the gradients:\n",
        "\n",
        "- `get_scaled_loss(loss)`: Multiplies the loss by the loss scale\n",
        "- `get_unscaled_gradients(gradients)`: Takes in a list of scaled gradients as inputs, and divides each one by the loss scale to unscale them\n",
        "\n",
        "These functions must be used in order to prevent underflow in the gradients. `LossScaleOptimizer.apply_gradients` will then apply gradients if none of them have `Inf`s or `NaN`s. It will also update the loss scale, halving it if the gradients had `Inf`s or `NaN`s and potentially increasing it otherwise."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V0vHlust4Rug"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def train_step(x, y):\n",
        "  with tf.GradientTape() as tape:\n",
        "    predictions = model(x, training=True)\n",
        "    loss = loss_object(y, predictions)\n",
        "    scaled_loss = optimizer.get_scaled_loss(loss)\n",
        "  scaled_gradients = tape.gradient(scaled_loss, model.trainable_variables)\n",
        "  gradients = optimizer.get_unscaled_gradients(scaled_gradients)\n",
        "  optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
        "  return loss"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rcFxEjia6YPQ"
      },
      "source": [
        "The `LossScaleOptimizer` will likely skip the first few steps at the start of training. The loss scale starts out high so that the optimal loss scale can quickly be determined. After a few steps, the loss scale will stabilize and very few steps will be skipped. This process happens automatically and does not affect training quality."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IHIvKKhg4Y-G"
      },
      "source": [
        "Now, define the test step:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nyk_xiZf42Tt"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def test_step(x):\n",
        "  return model(x, training=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hBs98MZyhBOB"
      },
      "source": [
        "Load the initial weights of the model, so you can retrain from scratch:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jpzOe3WEhFUJ"
      },
      "outputs": [],
      "source": [
        "model.set_weights(initial_weights)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s9Pi1ADM47Ud"
      },
      "source": [
        "Finally, run the custom training loop:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N274tJ3e4_6t"
      },
      "outputs": [],
      "source": [
        "for epoch in range(5):\n",
        "  epoch_loss_avg = tf.keras.metrics.Mean()\n",
        "  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
        "      name='test_accuracy')\n",
        "  for x, y in train_dataset:\n",
        "    loss = train_step(x, y)\n",
        "    epoch_loss_avg(loss)\n",
        "  for x, y in test_dataset:\n",
        "    predictions = test_step(x)\n",
        "    test_accuracy.update_state(y, predictions)\n",
        "  print('Epoch {}: loss={}, test accuracy={}'.format(epoch, epoch_loss_avg.result(), test_accuracy.result()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d7daQKGerOFE"
      },
      "source": [
        "## GPU performance tips\n",
        "\n",
        "Here are some performance tips when using mixed precision on GPUs.\n",
        "\n",
        "### Increasing your batch size\n",
        "\n",
        "If it doesn't affect model quality, try running with double the batch size when using mixed precision. As float16 tensors use half the memory, this often allows you to double your batch size without running out of memory. Increasing batch size typically increases training throughput, i.e. the training elements per second your model can run on.\n",
        "\n",
        "### Ensuring GPU Tensor Cores are used\n",
        "\n",
        "As mentioned previously, modern NVIDIA GPUs use a special hardware unit called Tensor Cores that can multiply float16 matrices very quickly. However, Tensor Cores requires certain dimensions of tensors to be a multiple of 8. In the examples below, an argument is bold if and only if it needs to be a multiple of 8 for Tensor Cores to be used.\n",
        "\n",
        "- tf.keras.layers.Dense(**units=64**)\n",
        "- tf.keras.layers.Conv2d(**filters=48**, kernel_size=7, stride=3)\n",
        "  - And similarly for other convolutional layers, such as tf.keras.layers.Conv3d\n",
        "- tf.keras.layers.LSTM(**units=64**)\n",
        "  - And similar for other RNNs, such as tf.keras.layers.GRU\n",
        "- tf.keras.Model.fit(epochs=2, **batch_size=128**)\n",
        "\n",
        "You should try to use Tensor Cores when possible. If you want to learn more, [NVIDIA deep learning performance guide](https://docs.nvidia.com/deeplearning/sdk/dl-performance-guide/index.html) describes the exact requirements for using Tensor Cores as well as other Tensor Core-related performance information.\n",
        "\n",
        "### XLA\n",
        "\n",
        "XLA is a compiler that can further increase mixed precision performance, as well as float32 performance to a lesser extent. Refer to the [XLA guide](https://www.tensorflow.org/xla) for details."
      ]
    },
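    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of the Tensor Core guideline in the GPU performance tips, the sketch below rounds a layer size up to the next multiple of 8. The helper name `round_up_to_multiple` and the sample sizes are hypothetical and not part of TensorFlow; this is only one possible way to pick friendly dimensions.\n",
        "\n",
        "```python\n",
        "# Hypothetical helper, not a TensorFlow API: returns the smallest\n",
        "# multiple of `multiple` that is greater than or equal to `value`.\n",
        "def round_up_to_multiple(value, multiple=8):\n",
        "  if value <= 0 or multiple <= 0:\n",
        "    raise ValueError('value and multiple must be positive')\n",
        "  return ((value + multiple - 1) // multiple) * multiple\n",
        "\n",
        "# Illustrative layer sizes, chosen only for this example.\n",
        "for units in (10, 50, 64, 1000):\n",
        "  print(units, '->', round_up_to_multiple(units))\n",
        "```\n",
        "\n",
        "Changing a dimension such as `units` or `filters` changes the model itself, so check that the rounded size does not hurt model quality before adopting it."
      ]
    },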
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2tFDX8fm6o_3"
      },
      "source": [
        "## Cloud TPU performance tips\n",
        "\n",
        "As with GPUs, you should try doubling your batch size when using Cloud TPUs because bfloat16 tensors use half the memory. Doubling batch size may increase training throughput.\n",
        "\n",
        "TPUs do not require any other mixed precision-specific tuning to get optimal performance. They already require the use of XLA. TPUs benefit from having certain dimensions being multiples of $128$, but this applies equally to the float32 type as it does for mixed precision. Check the [Cloud TPU performance guide](https://cloud.google.com/tpu/docs/performance-guide) for general TPU performance tips, which apply to mixed precision as well as float32 tensors."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "--wSEU91wO9w"
      },
      "source": [
        "## Summary\n",
        "\n",
        "- You should use mixed precision if you use TPUs or NVIDIA GPUs with at least compute capability 7.0, as it will improve performance by up to 3x.\n",
        "- You can use mixed precision with the following lines:\n",
        "\n",
        "  ```python\n",
        "  # On TPUs, use 'mixed_bfloat16' instead\n",
        "  mixed_precision.set_global_policy('mixed_float16')\n",
        "  ```\n",
        "\n",
        "* If your model ends in softmax, make sure it is float32. And regardless of what your model ends in, make sure the output is float32.\n",
        "* If you use a custom training loop with `mixed_float16`, in addition to the above lines, you need to wrap your optimizer with a `tf.keras.mixed_precision.LossScaleOptimizer`. Then call `optimizer.get_scaled_loss` to scale the loss, and `optimizer.get_unscaled_gradients` to unscale the gradients.\n",
        "* Double the training batch size if it does not reduce evaluation accuracy\n",
        "* On GPUs, ensure most tensor dimensions are a multiple of $8$ to maximize performance\n",
        "\n",
        "For more examples of mixed precision using the `tf.keras.mixed_precision` API, check the [official models repository](https://github.com/tensorflow/models/tree/master/official). Most official models, such as [ResNet](https://github.com/tensorflow/models/tree/master/official/vision/image_classification) and [Transformer](https://github.com/tensorflow/models/blob/master/official/nlp/transformer), will run using mixed precision by passing `--dtype=fp16`.\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "mixed_precision.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
